#!/usr/bin/python
# -*- coding: utf-8 -*-

# FIXME stop using the Lattice class, handle " " by multicharacter classes
# FIXME handle cost accounting for multi-character classes correctly
# FIXME different end-of-line handling
# FIXME use argparse for subparsers
# FIXME right now, it can't really "look back" to add rejected characters; they usually fall out of the beam too early

from pylab import *
from collections import Counter,defaultdict
import glob,re,heapq,os,cPickle,codecs
import ocrolib
from ocrolib import ngraphs as ng
from ocrolib.lattice import Lattice2

import argparse


# Usage examples printed after the argparse help when no input files are given.
# The %(prog)s placeholders are filled in with the name the script was invoked as.
# NOTE(review): the last example mentions a `--print` option, but no such
# option is registered with the argument parser in this file -- confirm.
_usage_template = """
Subcommands:

%(prog)s [options] line1.lattice line2.lattice ...

    Compute text output for each lattice file using the given language model.

%(prog)s --sample 20 -l langmod.ngraphs 

    Generate 20 samples from the given language model.

%(prog)s --build output.ngraphs --ngraph 3 textfile1.txt textfile2.txt ...

    Build a language model of order 3 from the given text files.

%(prog)s --print line.lattice

    Visualize the recognition lattice.
"""
extra = _usage_template%{"prog":sys.argv[0]}

# Command line definition.  The summary text has to be passed as `description`:
# the first positional parameter of ArgumentParser is `prog`, so passing the
# sentence positionally made it show up as the program name in the usage line.
parser = argparse.ArgumentParser(description="""Build, apply, and visualize n-graph language models.""")

# RESULT 0.0338945600778 cweight 1.07777736525 lmodel default-4.ngraphs lweight 0.162668906805 
# maxcost 15.4435581516 maxws 5.39191969674 mismatch 8.67617819976 thresh 1.287178097956


# building a language model
parser.add_argument('--build',default=None,help="build and write a language model")
parser.add_argument('--ngraph',type=int,default=4,help="order of the language model")

# sampling from a language model
parser.add_argument('--sample',default=None,type=int,help="sample from the language model")
parser.add_argument('--slength',default=70,type=int,help="length of the sampled strings")

parser.add_argument('-l','--lmodel',default=ocrolib.default.ngraphs,help="the language model (%(default)s)")

# search parameters
parser.add_argument('-C','--cweight',default=1.0,type=float,help="character weight (%(default)s)")
parser.add_argument('-L','--lweight',default=0.1,type=float,help="language model weight (%(default)s)")
parser.add_argument('-B','--beam',default=10,type=int,help="beam width (%(default)s)")
parser.add_argument('-P','--wsfactor',default=1.0,type=float,help="factor that whitespace costs are multiplied with (%(default)s)")
parser.add_argument('-W','--maxws',default=8,type=float,help="max whitespace cost (%(default)s)")
parser.add_argument('-M','--maxcost',default=15.0,type=float,help="max cost (%(default)s)")
parser.add_argument('-X','--mismatch',default=8,type=float,help="mismatch cost (%(default)s)")
parser.add_argument('-T','--thresh',default=0.0,type=float,help="below this cost, ignore language model")
parser.add_argument('-n','--nbest',type=int,default=5,help="the n best labels from each state to use from the lattice")

parser.add_argument('-q','--quiet',action="store_true",help="don't output each line")
parser.add_argument('--other',default=15.0,type=float,help="extra cost for characters outside the lattice")
parser.add_argument('--nother',default=1,type=int,help="number of candidates from outside the lattice")
parser.add_argument('--lother',default=-1,type=float,help="language model weight for other characters")
# print the best path of every state visited by the search
parser.add_argument('--debugpaths',action="store_true")
# comma-separated list of lattice states whose expansion is traced
parser.add_argument('--debugstates',default="")
# only paths up to this rank are traced within a debugged state
parser.add_argument('--debugmaxrank',type=int,default=4)
# multi-line output per file (cost, raw text if available, language model text)
parser.add_argument('--detailed',action="store_true")
# file with tab-separated rewrite rules ("add", from, to, cost)
parser.add_argument('--rewrites',default=None)
parser.add_argument("files",default=[],nargs="*")

args = parser.parse_args()
files = args.files
debugstates = [int(x) for x in args.debugstates.split(",")] if args.debugstates!="" else []

# a negative --lother means "use the same weight as --lweight"
if args.lother<0: args.lother = args.lweight

rewrites = None

if args.rewrites is not None:
    rewrites = defaultdict(list)
    rcost = 1.0
    nrewrites = 0
    with codecs.open(ocrolib.findfile(args.rewrites)) as stream:
        for line in stream.readlines():
            line = line[:-1]
            # print "*",line
            f = line.split("\t")
            assert f[0]=="add"
            rewrites[f[1]].append((f[2],rcost+float(f[3])))
            nrewrites += 1
    print "got",nrewrites,"rewrites"

class Path:
    """A (partial) hypothesis of the search through the lattice.

    Paths are ordered and compared by the tuple `(cost, state, path)`;
    `sequence` and `labels` do not take part in comparisons.
    """
    def __init__(self,cost=0.0,state=-1,path="",sequence=None,labels=None):
        self.cost = cost # total cost accumulated along this path
        self.state = state # state in the lattice
        self.path = path # current sequence of characters
        # A fresh list is created per instance when none is given; a literal
        # `[]` default would be shared between all instances.
        self.sequence = [] if sequence is None else sequence # current sequence of states
        self.labels = [] if labels is None else labels # current sequence of labels (list corresponding to sequence)
    def _key(self):
        # the tuple that defines ordering and equality of paths
        return (self.cost,self.state,self.path)
    def __repr__(self):
        return "<Path %.2f %d '%s'>"%(self.cost,self.state,self.path)
    def __str__(self):
        return self.__repr__()
    def __cmp__(self,other):
        # three-way comparison (-1, 0, 1) without relying on the `cmp` builtin
        a,b = self._key(),other._key()
        return (a>b)-(a<b)
    # Rich comparisons, consistent with __cmp__, so that sorting also works
    # on interpreters that ignore __cmp__.
    def __lt__(self,other):
        return self._key()<other._key()
    def __le__(self,other):
        return self._key()<=other._key()
    def __gt__(self,other):
        return self._key()>other._key()
    def __ge__(self,other):
        return self._key()>=other._key()
    def __eq__(self,other):
        return self._key()==other._key()
    def __ne__(self,other):
        return self._key()!=other._key()

def rewrite_path(path):
    """Return `path` together with all variants produced by the rewrite rules.

    Suffixes of `path.path` of length 1 up to `min(4,len(path.path))-1` are
    looked up in the module-level `rewrites` table.  Every matching rule
    yields a new `Path` in which the suffix is replaced, the rule cost is
    added, and the last label is replaced by "_".  State and sequence are
    taken over unchanged.  The original path is always the first element
    of the returned list.
    """
    text = path.path
    variants = [path]
    for width in range(1,min(4,len(text))):
        head,tail = text[:-width],text[-width:]
        for replacement,penalty in rewrites.get(tail,[]):
            # print path.path,"->",head+replacement,";",path.cost,path.cost+penalty
            variants.append(Path(cost=path.cost+penalty,
                                 state=path.state,
                                 path=head+replacement,
                                 sequence=path.sequence,
                                 labels=path.labels[:-1]+["_"]))
    return variants

def expand(path,ngraphs,
           cweight=1.0,lweight=1.0,
           rank=-1,
           verbose=0,
           missing=15.0,
           thresh=1.0,
           nbest=5,
           other=15.0,nother=1,lother=1.0,
           noreject=1):
    """Expand a search path.  Arguments are:

    - `path` the path to be expanded
    - `ngraphs` the ngraph model
    - `rank` the rank of the current path (for debugging)
    - `verbose` display extra information for debugging
    - `missing` the cost of missing characters in the posterior
    - `thresh` the treshold below which the language model cost is ignored entirely
    - `other` the cost for inserting non-lattice characters into the search
    - `nother` the number of non-lattice characters added (top # of characters from posterior)
    - `lother` the language model weight for non-lattice characters
    - `noreject` eliminate reject classes from matching
    """
    ngraphs.missing = {"~":missing}
    floor = missing
    lposteriors = ngraphs.getLogPosteriors(path.path)
    edges = lattice.edges[path.state]
    edges = sorted(edges,key=lambda e:e.cost)
    edges = edges[:nbest]
    result = []
    transitions = set()

    # add all the transitions for which we have edges
    for e in edges:
        if noreject and "~" in e.cls: continue
        assert e.start==path.state

        if e.cls!="" and e.cls!=" ":
            transitions.add((e.start,e.stop))

        # we apply the same string transformation to the predicted classes
        # as to the language model
        cls = ngraphs.lineproc(e.cls)

        # add transitions for single and multi-character classes
        # returned by the classifier
        l = 0.0 if e.cost<thresh and e.cls!=" " else lweight
        if len(cls)==0:
            ncost = path.cost + cweight*e.cost
            # FIXME we really need to add a penalty for not having whitespace here
            if verbose:
                print "EMPTY","ncost",ncost
        elif len(cls)==1:
            lcost = lposteriors.get(cls,floor)
            ncost = path.cost + cweight*e.cost + l*lcost
            if verbose:
                prefix = ngraphs.lineproc(path.path)[-5:]
                print "prefix",repr(prefix),"cls",repr(cls),"ecost",cweight*e.cost,"lcost",lcost,"ncost",ncost,"seg",e.seg
        else:
            ncost = path.cost + cweight*e.cost
            for c in cls:
                tpath = path.path + c
                tcls = tpath[-1]
                lcost = ngraphs.getLogPosteriors(tpath).get(c,floor)
                ncost += l*lcost
            if verbose:
                print "MULTI","prefix",repr(tpath[-10:]),"cls",repr(tcls),repr(cls),
                print "ecost",cweight*e.cost,"lcost",lcost,"ncost",ncost
        nsequence = path.sequence + [e]
        npath = path.path + e.cls
        nstate = e.stop
        nlabels = path.labels + [e.cls]
        assert nstate>path.state,("oops: %s %s %s %s"%(e.start,e.stop,cls,e.cost))
        result.append(Path(cost=ncost,state=nstate,path=npath,sequence=nsequence,labels=nlabels))

    # now add `nother` extra transitions for characters predicted by the language
    # model but not returned by the classifier; this adds the `other` cost
    # to the cost from the language model itself
    
    best = ngraphs.getBestGuesses(path.path,nother=nother)
    for start,stop in transitions:
        for (lcls,lcost) in best:
            ncost = path.cost + other + lcost
            nsequence = path.sequence + [None]
            npath = path.path + lcls
            nstate = stop
            nlabels = path.labels + [lcls]
            if verbose:
                print "OTHER","path",npath[-10:],"lcost",lcost
            result.append(Path(cost=ncost,state=nstate,path=npath,sequence=nsequence,labels=nlabels))

    return result        

def eliminate_common_suffixes_and_sort(paths,n):
    """Keep only the best path for each distinct suffix of length `n`.

    Paths are visited in sorted order, so the first one seen for a given
    `path[-n:]` suffix is the one that is kept.  The survivors are
    returned as a sorted list.
    """
    best_by_suffix = {}
    for candidate in sorted(paths):
        # setdefault leaves an already stored (better) path untouched
        best_by_suffix.setdefault(candidate.path[-n:],candidate)
    return sorted(best_by_suffix.values())

def search(lattice,ngraphs,accept=None,verbose=0,beam=100,**kw):
    global table
    N = ngraphs.N
    initial = Path(cost=0.0,state=lattice.startState(),path="_"*N)
    nstates = lattice.lastState()+1
    table = [[] for i in range(nstates)]
    table[initial.state] = [initial]

    for i in range(nstates):
        if lattice.isAccept(i): break
        if len(table[i])==0: continue
        table[i] = eliminate_common_suffixes_and_sort(table[i],n=N)

        # now apply the rewrites
        if rewrites is not None:
            npaths = []
            for p in table[i]: npaths += rewrite_path(p)
            table[i] = eliminate_common_suffixes_and_sort(npaths,n=N)

        if args.debugpaths: print i,table[i][0]
        if i in debugstates: print "=== state",i
        for rank,s in enumerate(table[i][:beam]):
            debugexpand = (rank<=args.debugmaxrank and i in debugstates)
            if debugexpand: print "\n--- EXPANDING",rank,s
            expanded = expand(s,ngraphs,rank=rank,verbose=debugexpand,**kw)
            for e in expanded:
                table[e.state].append(e)
        # table[i] = None

    result = eliminate_common_suffixes_and_sort(table[i],n=ngraphs.N)
    return result


if args.build is not None:
    fnames = []
    for pattern in args.files:
        if "=" in pattern:
            fnames += [pattern]
            continue
        l = glob.glob(pattern)
        assert len(l)>0,"%s: didn't expand to any files"%pattern
        for f in l:
            assert ".lattice" not in f
            assert ".png" not in f
        fnames += l
    print "got",len(fnames),"files"
    ngraphs = ng.NGraphs()
    ngraphs.buildFromFiles(fnames,args.ngraph)
    with open(args.build,"w") as stream:
        cPickle.dump(ngraphs,stream,2)
    sys.exit(0)



if args.sample is not None:
    args.lmodel = ocrolib.findfile(args.lmodel)
    print "loading",args.lmodel
    assert os.path.exists(args.lmodel),\
           "%s: cannot find language model"%args.lmodel
    with open(args.lmodel, 'rb') as stream:
        ngraphs = cPickle.load(stream)
    for i in range(args.sample):
        print ngraphs.sample(args.slength)
    sys.exit(0)
    
fnames = []
for pattern in args.files:
    l = sorted(glob.glob(pattern))
    for f in l:
        assert ".lattice" in f,"all files must end with .lattice"
    fnames += l

parser.add_argument('files',nargs='*')
if len(fnames)==0:
    parser.print_help()
    print extra
    sys.exit(0)

if ":" in args.lmodel:
    primary,secondary = args.lmodel.split(":")
    primary = ocrolib.findfile(primary)
    with open(primary, 'rb') as stream:
        primary = cPickle.load(stream)
    secondary = ocrolib.findfile(secondary)
    with open(secondary, 'rb') as stream:
        secondary = cPickle.load(stream)
    ngraphs = ng.NGraphsBackoff(primary,secondary)
else:
    args.lmodel = ocrolib.findfile(args.lmodel)
    print "loading",args.lmodel
    assert os.path.exists(args.lmodel),\
           "%s: cannot find language model"%args.lmodel
    with open(args.lmodel, 'rb') as stream:
        ngraphs = cPickle.load(stream)

def compute_cseg(path,rseg):
    """Derive a character segmentation and the aligned text from a path.

    Every labeled edge of `path.sequence` assigns the 1-based index of its
    entry in the text list to the raw segment numbers `seg[0]..seg[1]`
    (inclusive).  Steps with an empty label or without an edge (`None`)
    contribute nothing.  A trailing space of a label becomes a separate
    text entry that is not assigned to any segment.

    Returns `(cseg,gt)`: `rseg` with every raw segment number replaced by
    its character index (0 where unassigned), and the list of text entries.
    """
    limit = 10000
    assert amax(rseg)<limit,"rseg contains too many characters, or there is a bug somewhere"
    relabel = zeros(limit,'i')
    gt = []
    for pos,edge in enumerate(path.sequence):
        label = path.labels[pos]
        if label=="": continue
        # some of the labels end in " "; that space is split off so that
        # it is not treated as part of a ligature
        trailing = ""
        if label.endswith(" "):
            label,trailing = label[:-1]," "
        # inserted characters have no edge and hence no segments
        if edge is None: continue
        if label!="":
            gt.append(label)
            index = len(gt)
            first,last = edge.seg[0],edge.seg[1]
            for s in range(first,last+1):
                relabel[s] = index
        if trailing!="":
            gt.append(trailing)
    return relabel[rseg],gt

print "processing",len(fnames),"files"
for fname in fnames:
    if not args.quiet and not args.detailed: print fname,"=NGRAPHS=",
    lattice = Lattice2(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch,wsfactor=args.wsfactor)
    lattice.readLattice(fname)

    # search through the lattice for the best path under the ngraph model
    result = search(lattice,ngraphs,lweight=args.lweight,cweight=args.cweight,beam=args.beam,thresh=args.thresh,
                    other=args.other,nother=args.nother,lother=args.lother,nbest=args.nbest)

    # strip the initial context (we prepend "____" to create the line startup context)
    text = result[0].path[ngraphs.N:]

    # output the textual result
    if not args.quiet:
        if args.detailed:
            print "%5.2f %s"%(result[0].cost,fname)
            base,_ = ocrolib.allsplitext(fname)
            if os.path.exists(base+".raw.txt"):
                print "  RAW\t",ocrolib.read_text(base+".raw.txt")
            print "  LMD\t",text
        else:
            print "%5.2f\t%s"%(result[0].cost,text)
    fout = ocrolib.fvariant(fname,"txt")
    ocrolib.write_text(fout,text)

    # write a character segmentation file if there is a raw segmentation
    rname = ocrolib.fvariant(fname,"rseg")
    cname = ocrolib.fvariant(fname,"cseg")
    if os.path.exists(rname):
        rseg = ocrolib.read_line_segmentation(rname)
        cseg,ctxt = compute_cseg(result[0],rseg)
        ocrolib.write_line_segmentation(cname,cseg)
        ocrolib.write_text(ocrolib.fvariant(fname,"aligned"),ocrolib.gt_implode(ctxt))
    else:
        print rname,": not found"

